import os

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys

cwd = os.getcwd()
sys.path.append(cwd.replace('/interface', ''))
print(sys.path)
from evaluate.evaluate_distrib_rl import visualize_exp_variance_by_location, \
    compute_action_output_with_features, visualize_uncertainty_by_location, visualize_distribution_plot, \
    plot_game_quantiles_values, generate_game_plot
from generic.model_util import get_distrib_q_model_save_path

from agent import SportsAgent
from generic.data_util import load_config, read_args, ICEHOCKEY_ACTIONS, divide_dataset_according2date


def test(args):
    """Evaluate a pretrained distributional-Q model by plotting one game.

    The config returned by ``load_config(args)`` is overridden with the
    hard-coded evaluation settings defined at the top of this function. A
    ``SportsAgent`` is then built from it, the checkpoint
    ``<model dir>/saved_model_<load_dqn_episode>`` is loaded (the model dir
    comes from ``get_distrib_q_model_save_path`` for ``date``), and
    ``generate_game_plot`` is called for a single hard-coded game.

    Several alternative evaluations are kept below as commented-out code;
    re-enable them by hand as needed.

    Args:
        args: command-line arguments (as returned by ``read_args``),
            forwarded unchanged to ``load_config``.
    """
    # --- Hard-coded evaluation settings (they override the loaded config). ---
    test_num_tau = 256
    test_num_supp = 256
    test_gamma = 1
    test_train_rate = 0.8
    max_trace_length = 3
    apply_dynamic_trace_length = False
    test_apply_rnn = True
    test_apply_resnet = True
    test_cut_at_goal = True
    # The next two settings are only read by the commented-out experiments.
    action_only = False
    gda_fitting_target = "QValues"
    load_dqn_episode = 10000  # 'best'  # 10000
    date = 'Nov-19-2021'  # 'Oct-31-2021', 'Nov-19-2021', 'Oct-29-2021', 'Nov-04-2021'
    # best_label = 'best'
    debug_msg = ''
    sanity_check_msg = None

    config, debug_mode, log_file_path = load_config(args)
    # The log file is optional. When it is opened, the ``finally`` clause
    # below closes it, so the handle is released (and buffered output
    # flushed) even if loading or plotting raises.
    log_file = open(log_file_path, 'w') if log_file_path is not None else None
    try:
        model_config = config['general']['model']
        model_config['apply_rnn'] = test_apply_rnn
        model_config['apply_resnet'] = test_apply_resnet
        model_config['num_tau'] = test_num_tau
        model_config['num_supp'] = test_num_supp
        model_config['apply_dynamic_trace_length'] = apply_dynamic_trace_length
        model_config['max_trace_length'] = max_trace_length

        training_config = config['general']['training']
        training_config['gamma'] = test_gamma
        training_config['cut_at_goal'] = test_cut_at_goal
        training_config['train_rate'] = test_train_rate

        # Disable CUDA for this evaluation run.
        config['general']['use_cuda'] = False

        # Sanity-check mode (disabled).
        # NOTE(review): `ACTIONS` is not imported in this file (only
        # ICEHOCKEY_ACTIONS is) — fix the name before re-enabling.
        # print('-' * 100, file=log_file, flush=True)
        # print("*** Warning: Launching the sanity check. ***", file=log_file, flush=True)
        # config['general']['model']['input_dim'] = len(ACTIONS) + 4
        # sanity_check_msg = 'sanity_check_location_ha_'  # sanity_check_location_ha_, sanity_check_sd_md_tr_ha_
        # debug_msg = sanity_check_msg + debug_msg
        # print('-' * 100, file=log_file, flush=True)

        agent = SportsAgent(config=config, log_file=log_file)
        model_save_mother_dir = get_distrib_q_model_save_path(agent=agent,
                                                              date_label=date,
                                                              debug_msg=debug_msg)
        # model_save_mother_dir += best_label
        dqn_load_from_path = model_save_mother_dir + '/saved_model_{0}'.format(load_dqn_episode)
        # Only the second element of the loader's return tuple is kept; it is
        # read by the commented-out experiments below.
        _, episode_num, _, _, _ = agent.load_pretrained_model(load_from=dqn_load_from_path,
                                                              load_optim=False,
                                                              log_file=log_file)

        # model_label = debug_msg + 'distribution' + model_save_mother_dir.split('model')[-1] + '_{0}'.format(load_dqn_episode)
        # visualize_distribution_plot(agent,
        #                             is_home=1,
        #                             target_action='shot',
        #                             debug_mode=debug_mode,
        #                             model_label=model_label,
        #                             sanity_check_msg=sanity_check_msg)
        #
        # agent.fit_gda(
        #     gda_fitting_target=gda_fitting_target,
        #     debug_mode=debug_mode,
        #     sanity_check_msg=sanity_check_msg,
        #     log_file=log_file,
        # )
        # visualize_uncertainty_by_location(agent=agent,
        #                                   debug_mode=debug_mode,
        #                                   sanity_check_msg=sanity_check_msg,
        #                                   episode_num=episode_num,
        #                                   mode='test',
        #                                   model_label=debug_msg + '_heatmap' +
        #                                               model_save_mother_dir.split('model')[-1],
        #                                   log_file=log_file)
        # NOTE(review): the loop below reads `testing_files`, which is only
        # defined by the (also commented-out) dataset split further down —
        # re-enable that split first.
        # action_samples_dict = {}
        # for i in range(len(testing_files)):
        #     game_name = testing_files[i]
        #     action_samples_dict = \
        #         compute_action_output_with_features(agent=agent,
        #                                             game_name=game_name,
        #                                             is_home=True,
        #                                             action_samples_dict=action_samples_dict,
        #                                             sanity_check_msg=sanity_check_msg,
        #                                             output_type='Uncertainty')

        # plot_game_quantiles_values(agent=agent,
        #                            game_name=game_name,
        #                            sanity_check_msg=sanity_check_msg,
        #                            is_home=True,
        #                            action_samples_dict={},
        #                            action_only=action_only)

        # plot_save_path = './figures_for_evaluation/{0}/game-{1}{3}{4}-tau{5}-supp{6}-gamma{7}-temporal-projection-{2}'.format(
        #     'distrib_dqn' if agent.task == 'train_distrib_rl' else 'dqn',
        #     game_name,
        #     episode_num,
        #     '-rnn' if agent.apply_rnn else '',
        #     '-resnet' if agent.apply_resnet else '',
        #     test_num_tau,
        #     test_num_supp,
        #     test_gamma
        # )
        # all_files = sorted(os.listdir(agent.train_data_path))
        # training_files, validation_files, testing_files, split_dates = \
        #     divide_dataset_according2date(all_data_files=all_files,
        #                                   train_rate=agent.train_rate,
        #                                   sports=agent.sports,
        #                                   if_split=agent.apply_data_date_div,
        #                                   if_return_split=True
        #                                   )
        # for game_name in testing_files:
        game_name = '17080'
        # NOTE(review): `episode_num` is passed as the literal 'testing', not
        # the episode number returned by load_pretrained_model — confirm
        # this is intended.
        generate_game_plot(agent=agent,
                           date_label=date,
                           sanity_check_msg=sanity_check_msg,
                           game_name=game_name,
                           episode_num='testing',
                           plot_save_path=None,
                           alpha='mean',
                           debug_msg=debug_msg)

        # NOTE(review): compute_action_variance_all_game and
        # contextualized_empirical_risk_measure are not imported in this file.
        # compute_action_variance_all_game(agent)
        #
        # visualize_exp_variance_by_location(agent=agent,
        #                                    debug_mode=debug_mode,
        #                                    episode_num=episode_num,
        #                                    mode='test',
        #                                    model_label=debug_msg +
        #                                                '_heatmap' +
        #                                                model_save_mother_dir.split('model')[-1]
        #                                    )
        # save_calibration_file_path = './calibration_results_record/{0}calibration_results'.format(debug_msg) + \
        #                              model_save_mother_dir.split('saved')[-1] + '{0}.txt'.format('_debug' if debug_mode else '')
        # with open(save_calibration_file_path, 'w') as record_file:
        #     contextualized_empirical_risk_measure(agent=agent,
        #                                           # model_label=debug_msg + model_save_mother_dir.split('model')[-1],
        #                                           record_file=record_file,
        #                                           sanity_check_msg=sanity_check_msg,
        #                                           debug_mode=debug_mode,
        #                                           verbose=True)
    finally:
        if log_file is not None:
            log_file.close()


if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    test(read_args())
